{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a742b29e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-09T12:08:11.010991Z",
     "start_time": "2023-11-09T12:08:11.007039Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86120810",
   "metadata": {},
   "source": [
    "##  einsum review"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "367db35d",
   "metadata": {},
   "source": [
    "\n",
    "> 'ik,kj->ij'\n",
    "\n",
    "- free indices: specified in the output (-> 之后/右边的)\n",
    "- summation indices: all other (也就是没有出现在 -> 右边的，但出现在 -> 左边的 index 都会被求和 reduce 掉)\n",
    "\n",
    "\n",
    "$$\n",
    "(A\\cdot B)_{ij} = \\sum_k A_{ik}\\cdot B_{kj}\\\\\n",
    "AB_{ij}=A_{ik}B_{kj}\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d338dbb7",
   "metadata": {},
   "source": [
    "## bilinear"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3d69329",
   "metadata": {},
   "source": [
    "$$\n",
    "y=x_1^TAx_2+b\n",
    "$$\n",
    "\n",
    "- $x_1\\in\\mathbb R^{d_1}$: 列向量（数学中默认）\n",
    "- $x_2\\in\\mathbb R^{d_2}$：列向量（数学中默认）\n",
    "- $A\\in \\mathbb R^{d_1\\times d_2}$\n",
    "- $y$ 是一个 scalar"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bffb821b",
   "metadata": {},
   "source": [
    "## `torch.nn.functional.bilinear`/`F.bilinear`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c57b80e2",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-09T12:09:22.509095Z",
     "start_time": "2023-11-09T12:09:22.494124Z"
    }
   },
   "outputs": [],
   "source": [
    "x_1 = torch.randn((3, 1))\n",
    "x_2 = torch.randn((2, 1))\n",
    "A = torch.randn((3, 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4a59dcc1",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-09T12:09:42.368499Z",
     "start_time": "2023-11-09T12:09:42.333871Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.7900]])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.mm(x_1.T, torch.mm(A, x_2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "f2e56e03",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-09T12:12:15.423360Z",
     "start_time": "2023-11-09T12:12:15.402851Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0.7900]]])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "F.bilinear(x_1.T.unsqueeze(0), x_2.T.unsqueeze(0), A.unsqueeze(0))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c1f5ed6",
   "metadata": {},
   "source": [
    "## `torch.nn.Bilinear`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "1cb2e895",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-09T12:44:39.181375Z",
     "start_time": "2023-11-09T12:44:39.171625Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([40, 20, 30])\n",
      "torch.Size([40])\n"
     ]
    }
   ],
   "source": [
    "m = nn.Bilinear(20, 30, 40)\n",
    "# out * in1 * in2\n",
    "print(m.weight.shape)\n",
    "print(m.bias.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "8db7cb7f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-09T12:44:40.250471Z",
     "start_time": "2023-11-09T12:44:40.243742Z"
    }
   },
   "outputs": [],
   "source": [
    "# 128 pairs (x1, x2)\n",
    "# x1: (20, 1)^T\n",
    "# x2: (30, 1)^T\n",
    "input1 = torch.randn(128, 20)\n",
    "input2 = torch.randn(128, 30)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "e8a7693f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-09T12:44:41.931126Z",
     "start_time": "2023-11-09T12:44:41.922995Z"
    }
   },
   "outputs": [],
   "source": [
    "output = m(input1, input2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "id": "7ef7371a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-09T12:44:43.008841Z",
     "start_time": "2023-11-09T12:44:42.998953Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([128, 40])"
      ]
     },
     "execution_count": 69,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "id": "2ef73d9e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-09T12:44:43.979031Z",
     "start_time": "2023-11-09T12:44:43.966634Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.5529, -1.8147, -1.4468,  ...,  2.4288,  0.1620, -1.5111],\n",
       "        [-2.1974, -4.0457, -0.4051,  ...,  1.7315,  4.8727,  5.2919],\n",
       "        [-1.4074,  1.8243,  0.7437,  ..., -0.7086, -0.5513,  0.8196],\n",
       "        ...,\n",
       "        [-2.5626,  5.0811, -3.8342,  ...,  4.6002,  0.7150,  1.7936],\n",
       "        [-0.9222,  1.4171,  1.4422,  ...,  3.3310,  1.9861, -3.7736],\n",
       "        [-2.5675,  6.2278,  0.9713,  ...,  2.5801, -2.5891,  0.3585]],\n",
       "       grad_fn=<AddBackward0>)"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "a9ff2085",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-09T12:44:45.926258Z",
     "start_time": "2023-11-09T12:44:45.543618Z"
    }
   },
   "outputs": [],
   "source": [
    "Y = torch.zeros(128, 40)\n",
    "# (x1, x2)\n",
    "# b (free index)\n",
    "for i in range(128):\n",
    "    # A\n",
    "    # a (free index)\n",
    "    for j in range(40):\n",
    "        Y[i, j] = input1[i, :].view(1, -1) @ (m.weight[j, :, :] @ input2[i, :].view(-1, 1)) + m.bias[j]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "id": "ae3f4d2a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-09T12:44:46.760885Z",
     "start_time": "2023-11-09T12:44:46.752565Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.5529, -1.8147, -1.4468,  ...,  2.4288,  0.1620, -1.5111],\n",
       "        [-2.1974, -4.0457, -0.4051,  ...,  1.7315,  4.8727,  5.2919],\n",
       "        [-1.4074,  1.8243,  0.7437,  ..., -0.7086, -0.5513,  0.8196],\n",
       "        ...,\n",
       "        [-2.5626,  5.0811, -3.8342,  ...,  4.6002,  0.7150,  1.7936],\n",
       "        [-0.9222,  1.4171,  1.4422,  ...,  3.3310,  1.9861, -3.7736],\n",
       "        [-2.5675,  6.2278,  0.9713,  ...,  2.5801, -2.5891,  0.3585]],\n",
       "       grad_fn=<CopySlices>)"
      ]
     },
     "execution_count": 72,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "79748b0f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-09T12:44:49.451437Z",
     "start_time": "2023-11-09T12:44:49.437734Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.5529, -1.8147, -1.4468,  ...,  2.4288,  0.1620, -1.5111],\n",
       "        [-2.1974, -4.0457, -0.4051,  ...,  1.7315,  4.8727,  5.2919],\n",
       "        [-1.4074,  1.8243,  0.7437,  ..., -0.7086, -0.5513,  0.8196],\n",
       "        ...,\n",
       "        [-2.5626,  5.0811, -3.8342,  ...,  4.6002,  0.7150,  1.7936],\n",
       "        [-0.9222,  1.4171,  1.4422,  ...,  3.3310,  1.9861, -3.7736],\n",
       "        [-2.5675,  6.2278,  0.9713,  ...,  2.5801, -2.5891,  0.3585]],\n",
       "       grad_fn=<AddBackward0>)"
      ]
     },
     "execution_count": 73,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# input1: b, n (128, 20)\n",
    "# input2: b, m (128, 30)\n",
    "# A: a, n, m (40, 20, 30)\n",
    "torch.einsum('bn,abn->ba', input1, torch.einsum('anm,bm->abn', m.weight, input2)) + m.bias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "e30e2776",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-09T12:44:51.327580Z",
     "start_time": "2023-11-09T12:44:51.314860Z"
    },
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.5529, -1.8147, -1.4468,  ...,  2.4288,  0.1620, -1.5111],\n",
       "        [-2.1974, -4.0457, -0.4051,  ...,  1.7315,  4.8727,  5.2919],\n",
       "        [-1.4074,  1.8243,  0.7437,  ..., -0.7086, -0.5513,  0.8196],\n",
       "        ...,\n",
       "        [-2.5626,  5.0811, -3.8342,  ...,  4.6002,  0.7150,  1.7936],\n",
       "        [-0.9222,  1.4171,  1.4422,  ...,  3.3310,  1.9861, -3.7736],\n",
       "        [-2.5675,  6.2278,  0.9713,  ...,  2.5801, -2.5891,  0.3585]],\n",
       "       grad_fn=<AddBackward0>)"
      ]
     },
     "execution_count": 74,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.einsum('bn,anm,bm->ba', input1, m.weight, input2) + m.bias"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
